from typing import Tuple, Union

import numpy as np

from .utils import (Array, QuantileFunction, SecondQuantileFunction,
                    num_integrate_func)


class MRM:
    r"""Mean-risk models evaluated on a 1D sample of scores.

    Every risk measure ``r`` exposed here is paired with the sample mean
    $\mu_X$ to form a mean-risk score ``mean - r`` (see :meth:`compute`).
    The dual-space measures are built on the second quantile function
    $F_X^{(-2)}$ supplied by ``SecondQuantileFunction``.

    Args:
        scores: 1D array-like of sample scores.
        p_default: probability level at which the ``p``-dependent measures
            (``h_X_p``, ``gini_X_p``, ``ntvar_X_p``) are evaluated by
            :meth:`compute`.
        dp: step size handed to ``num_integrate_func`` for the numerical
            integration behind ``gamma_X`` and ``gini_X_p``.

    Raises:
        ValueError: if ``scores`` is not one-dimensional or is empty.
    """

    def __init__(self,
                 scores: Array,
                 p_default: float = 0.1,
                 dp: float = 0.001) -> None:
        # Check and cache inputs
        self.scores = np.array(scores)
        if self.scores.ndim != 1:
            raise ValueError("Scores should be a 1D array")
        if self.scores.size == 0:
            # The mean (and every measure built on it) is undefined for an
            # empty sample; fail early instead of propagating NaNs.
            raise ValueError("Scores should not be empty")
        self.dp = dp
        self.p_default = p_default

        # Cache some computations
        self.quantile_function = QuantileFunction(self.scores)
        self.second_quantile_function = SecondQuantileFunction(self.scores)
        # h_X(p) = mu_X * p - F_X^{(-2)}(p)
        self.h_X = lambda p: self.mean * p - self.second_quantile_function(p)
        # 2 * (numerical integral of h_X from 0 to p). num_integrate_func is
        # set up over [0, 1] with step dp; its exact scheme lives in .utils.
        self._gamma_X = lambda p: 2 * num_integrate_func(
            self.h_X, 0.0, 1.0, self.dp)(p)

    @staticmethod
    def _as_array(p: Union[float, Array]) -> Union[float, Array]:
        """Convert list/tuple probability levels to an ndarray.

        Scalars and existing arrays are returned unchanged so the scalar
        code path is untouched.
        """
        if isinstance(p, (list, tuple)):
            return np.array(p)
        return p

    def _shortfall(self) -> np.ndarray:
        r"""Per-sample downside deviation $\max(\mu_X - X, 0)$."""
        return np.maximum(self.mean - self.scores, 0.0)

    @property
    def mean(self) -> float:
        """Mean of the scores
        """
        return np.mean(self.scores).item()

    @property
    def delta(self) -> float:
        r"""Absolute semideviation
            computed as $E[\max(\mu_X - X, 0)]$
        """
        return np.mean(self._shortfall()).item()

    @property
    def sigma(self) -> float:
        r"""Standard semideviation
            computed as $\sqrt{E[\max(\mu_X - X, 0)^2]}$
        """
        return np.sqrt(np.mean(self._shortfall() ** 2)).item()

    @property
    def gamma_X(self) -> float:
        r"""Double area of dual dispersion space
            defined as $\Gamma_X = 2 \int_0^1 (\mu_X p - F_X^{(-2)}(p)) dp$
        """
        return self._gamma_X(1.0)

    def h_X_p(self, p: Union[float, Array]) -> np.ndarray:
        r"""Dual dispersion space
            defined as $h_X(p) / p$, where $h_X(p) = \mu_X p - F_X^{(-2)}(p)$

        ``p`` may be a scalar, a list/tuple, or an array of probability
        levels; lists and tuples are converted to an ndarray first.
        """
        p = self._as_array(p)
        return self.h_X(p) / p

    def gini_X_p(self, p: Union[float, Array]) -> np.ndarray:
        r"""Tail Gini's measure
            defined as $G_X(p) = \frac{2}{p^2} \int_0^p (\mu_X\alpha - F_X^{(-2)}(\alpha)) d\alpha$

        ``p`` may be a scalar, a list/tuple, or an array of probability
        levels; lists and tuples are converted to an ndarray first.
        """
        p = self._as_array(p)
        return self._gamma_X(p) / p**2

    def ntvar_X_p(self, p: Union[float, Array]) -> np.ndarray:
        r"""Negative tail value at risk, -TVaR_X(p),
            defined as $TVaR_X(p) = F_X^{(-2)}(p) / p$

        ``p`` may be a scalar, a list/tuple, or an array of probability
        levels; lists and tuples are converted to an ndarray first.
        """
        p = self._as_array(p)
        return -self.second_quantile_function(p) / p

    def compute(self) -> Tuple[dict, dict]:
        """Compute all the scores from the mean-risk models

        Returns:
            A ``(scores, info)`` pair. ``scores`` maps ``"mean-<model>"`` to
            ``mean - risk`` for each model. ``info`` holds ``'mean'``,
            ``'p_eval'`` (the ``p_default`` used for the ``p``-dependent
            measures) and the raw risk value of every model.
        """
        risk_models = (
            'delta', 'sigma', 'gamma_X', 'h_X_p', 'gini_X_p', 'ntvar_X_p'
        )

        # The mean is constant for this instance; evaluate it once.
        mean = self.mean
        scores = {}
        info = {'mean': mean, 'p_eval': self.p_default}
        for model_name in risk_models:
            risk = getattr(self, model_name)
            # Properties are already values; p-dependent measures are
            # methods and get evaluated at p_default.
            if callable(risk):
                risk = risk(self.p_default)
            info[model_name] = risk
            scores["mean-" + model_name] = mean - risk

        return scores, info
